From 8e6299b89886fd24586d45656be19c7253c8362b Mon Sep 17 00:00:00 2001 From: Manuel Schmid Date: Sat, 4 May 2024 20:36:47 +0200 Subject: [PATCH] feat: add VAE select --- ldm_patched/modules/sd.py | 11 +++++++---- modules/async_worker.py | 3 ++- modules/config.py | 14 +++++++++++++- modules/core.py | 8 +++++--- modules/default_pipeline.py | 19 ++++++++++++------- modules/util.py | 3 +++ webui.py | 11 +++++++---- 7 files changed, 49 insertions(+), 20 deletions(-) diff --git a/ldm_patched/modules/sd.py b/ldm_patched/modules/sd.py index e197c39c..61260fd0 100644 --- a/ldm_patched/modules/sd.py +++ b/ldm_patched/modules/sd.py @@ -427,7 +427,7 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl return (ldm_patched.modules.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae) -def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True): +def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, vae_filename=None): sd = ldm_patched.modules.utils.load_torch_file(ckpt_path) sd_keys = sd.keys() clip = None @@ -462,8 +462,11 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o model.load_model_weights(sd, "model.diffusion_model.") if output_vae: - vae_sd = ldm_patched.modules.utils.state_dict_prefix_replace(sd, {"first_stage_model.": ""}, filter_keys=True) - vae_sd = model_config.process_vae_state_dict(vae_sd) + if vae_filename is None: + vae_sd = ldm_patched.modules.utils.state_dict_prefix_replace(sd, {"first_stage_model.": ""}, filter_keys=True) + vae_sd = model_config.process_vae_state_dict(vae_sd) + else: + vae_sd = ldm_patched.modules.utils.load_torch_file(vae_filename) vae = VAE(sd=vae_sd) if output_clip: @@ -485,7 +488,7 @@ def 
load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o print("loaded straight to GPU") model_management.load_model_gpu(model_patcher) - return (model_patcher, clip, vae, clipvision) + return model_patcher, clip, vae, clipvision def load_unet_state_dict(sd): #load unet in diffusers format diff --git a/modules/async_worker.py b/modules/async_worker.py index d8a1e072..f559806a 100644 --- a/modules/async_worker.py +++ b/modules/async_worker.py @@ -166,6 +166,7 @@ def worker(): adaptive_cfg = args.pop() sampler_name = args.pop() scheduler_name = args.pop() + vae_name = args.pop() overwrite_step = args.pop() overwrite_switch = args.pop() overwrite_width = args.pop() @@ -428,7 +429,7 @@ def worker(): progressbar(async_task, 3, 'Loading models ...') pipeline.refresh_everything(refiner_model_name=refiner_model_name, base_model_name=base_model_name, loras=loras, base_model_additional_loras=base_model_additional_loras, - use_synthetic_refiner=use_synthetic_refiner) + use_synthetic_refiner=use_synthetic_refiner, vae_name=vae_name) progressbar(async_task, 3, 'Processing prompts ...') tasks = [] diff --git a/modules/config.py b/modules/config.py index b81e218a..fdb47157 100644 --- a/modules/config.py +++ b/modules/config.py @@ -189,6 +189,7 @@ paths_checkpoints = get_dir_or_set_default('path_checkpoints', ['../models/check paths_loras = get_dir_or_set_default('path_loras', ['../models/loras/'], True) path_embeddings = get_dir_or_set_default('path_embeddings', '../models/embeddings/') path_vae_approx = get_dir_or_set_default('path_vae_approx', '../models/vae_approx/') +path_vae = get_dir_or_set_default('path_vae', '../models/vae/') path_upscale_models = get_dir_or_set_default('path_upscale_models', '../models/upscale_models/') path_inpaint = get_dir_or_set_default('path_inpaint', '../models/inpaint/') path_controlnet = get_dir_or_set_default('path_controlnet', '../models/controlnet/') @@ -346,6 +347,11 @@ default_scheduler = get_config_item_or_set_default( 
default_value='karras', validator=lambda x: x in modules.flags.scheduler_list ) +default_vae = get_config_item_or_set_default( + key='default_vae', + default_value='None', + validator=lambda x: isinstance(x, str) +) default_styles = get_config_item_or_set_default( key='default_styles', default_value=[ @@ -535,6 +541,7 @@ with open(config_example_path, "w", encoding="utf-8") as json_file: model_filenames = [] lora_filenames = [] +vae_filenames = [] wildcard_filenames = [] sdxl_lcm_lora = 'sdxl_lcm_lora.safetensors' @@ -546,15 +553,20 @@ def get_model_filenames(folder_paths, extensions=None, name_filter=None): if extensions is None: extensions = ['.pth', '.ckpt', '.bin', '.safetensors', '.fooocus.patch'] files = [] + + if not isinstance(folder_paths, list): + folder_paths = [folder_paths] for folder in folder_paths: files += get_files_from_folder(folder, extensions, name_filter) + return files def update_files(): - global model_filenames, lora_filenames, wildcard_filenames, available_presets + global model_filenames, lora_filenames, vae_filenames, wildcard_filenames, available_presets model_filenames = get_model_filenames(paths_checkpoints) lora_filenames = get_model_filenames(paths_loras) + vae_filenames = get_model_filenames(path_vae) wildcard_filenames = get_files_from_folder(path_wildcards, ['.txt']) available_presets = get_presets() return diff --git a/modules/core.py b/modules/core.py index 38ee8e8d..c1fa8976 100644 --- a/modules/core.py +++ b/modules/core.py @@ -35,12 +35,13 @@ opModelSamplingDiscrete = ModelSamplingDiscrete() class StableDiffusionModel: - def __init__(self, unet=None, vae=None, clip=None, clip_vision=None, filename=None): + def __init__(self, unet=None, vae=None, clip=None, clip_vision=None, filename=None, vae_filename=None): self.unet = unet self.vae = vae self.clip = clip self.clip_vision = clip_vision self.filename = filename + self.vae_filename = vae_filename self.unet_with_lora = unet self.clip_with_lora = clip self.visited_loras = '' @@ 
-142,8 +143,9 @@ def apply_controlnet(positive, negative, control_net, image, strength, start_per @torch.no_grad() @torch.inference_mode() -def load_model(ckpt_filename): - unet, clip, vae, clip_vision = load_checkpoint_guess_config(ckpt_filename, embedding_directory=path_embeddings) +def load_model(ckpt_filename, vae_filename=None): + unet, clip, vae, clip_vision = load_checkpoint_guess_config(ckpt_filename, embedding_directory=path_embeddings, + vae_filename=vae_filename) - return StableDiffusionModel(unet=unet, clip=clip, vae=vae, clip_vision=clip_vision, filename=ckpt_filename) + return StableDiffusionModel(unet=unet, clip=clip, vae=vae, clip_vision=clip_vision, filename=ckpt_filename, vae_filename=vae_filename) diff --git a/modules/default_pipeline.py b/modules/default_pipeline.py index 190601ec..b74faf46 100644 --- a/modules/default_pipeline.py +++ b/modules/default_pipeline.py @@ -58,16 +58,20 @@ def assert_model_integrity(): @torch.no_grad() @torch.inference_mode() -def refresh_base_model(name): +def refresh_base_model(name, vae_name=None): global model_base filename = get_file_from_folder_list(name, modules.config.paths_checkpoints) - if model_base.filename == filename: + vae_filename = None + if vae_name is not None and vae_name != 'None': + vae_filename = get_file_from_folder_list(vae_name, modules.config.path_vae) + + if model_base.filename == filename and model_base.vae_filename == vae_filename: return model_base = core.StableDiffusionModel() - model_base = core.load_model(filename) + model_base = core.load_model(filename, vae_filename) print(f'Base model loaded: {model_base.filename}') return @@ -216,7 +220,7 @@ def prepare_text_encoder(async_call=True): @torch.no_grad() @torch.inference_mode() def refresh_everything(refiner_model_name, base_model_name, loras, - base_model_additional_loras=None, use_synthetic_refiner=False): + base_model_additional_loras=None, use_synthetic_refiner=False, vae_name=None): global final_unet, final_clip, final_vae, final_refiner_unet, final_refiner_vae, final_expansion final_unet = None @@ -227,11 +231,11 @@ def refresh_everything(refiner_model_name, 
base_model_name, loras, if use_synthetic_refiner and refiner_model_name == 'None': print('Synthetic Refiner Activated') - refresh_base_model(base_model_name) + refresh_base_model(base_model_name, vae_name) synthesize_refiner_model() else: refresh_refiner_model(refiner_model_name) - refresh_base_model(base_model_name) + refresh_base_model(base_model_name, vae_name) refresh_loras(loras, base_model_additional_loras=base_model_additional_loras) assert_model_integrity() @@ -254,7 +258,8 @@ def refresh_everything(refiner_model_name, base_model_name, loras, refresh_everything( refiner_model_name=modules.config.default_refiner_model_name, base_model_name=modules.config.default_base_model_name, - loras=get_enabled_loras(modules.config.default_loras) + loras=get_enabled_loras(modules.config.default_loras), + vae_name=modules.config.default_vae, ) diff --git a/modules/util.py b/modules/util.py index 9e0fb294..d2feecb6 100644 --- a/modules/util.py +++ b/modules/util.py @@ -371,6 +371,9 @@ def is_json(data: str) -> bool: def get_file_from_folder_list(name, folders): + if not isinstance(folders, list): + folders = [folders] + for folder in folders: filename = os.path.abspath(os.path.realpath(os.path.join(folder, name))) if os.path.isfile(filename): diff --git a/webui.py b/webui.py index 98780bff..ab52a63f 100644 --- a/webui.py +++ b/webui.py @@ -406,6 +406,8 @@ with shared.gradio_root: value=modules.config.default_sampler) scheduler_name = gr.Dropdown(label='Scheduler', choices=flags.scheduler_list, value=modules.config.default_scheduler) + vae_name = gr.Dropdown(label='VAE', choices=['None'] + modules.config.vae_filenames, + value=modules.config.default_vae, show_label=True) generate_image_grid = gr.Checkbox(label='Generate Image Grid for Each Batch', info='(Experimental) This may cause performance problems on some computers and certain internet conditions.', @@ -528,6 +530,7 @@ with shared.gradio_root: modules.config.update_files() results = 
[gr.update(choices=modules.config.model_filenames)] results += [gr.update(choices=['None'] + modules.config.model_filenames)] + results += [gr.update(choices=['None'] + modules.config.vae_filenames)] if not args_manager.args.disable_preset_selection: results += [gr.update(choices=modules.config.available_presets)] for i in range(modules.config.default_max_lora_number): @@ -535,7 +538,7 @@ with shared.gradio_root: gr.update(choices=['None'] + modules.config.lora_filenames), gr.update()] return results - refresh_files_output = [base_model, refiner_model] + refresh_files_output = [base_model, refiner_model, vae_name] if not args_manager.args.disable_preset_selection: refresh_files_output += [preset_selection] refresh_files.click(refresh_files_clicked, [], refresh_files_output + lora_ctrls, @@ -547,8 +550,8 @@ with shared.gradio_root: performance_selection, overwrite_step, overwrite_switch, aspect_ratios_selection, overwrite_width, overwrite_height, guidance_scale, sharpness, adm_scaler_positive, adm_scaler_negative, adm_scaler_end, refiner_swap_method, adaptive_cfg, base_model, - refiner_model, refiner_switch, sampler_name, scheduler_name, seed_random, image_seed, - generate_button, load_parameter_button] + freeu_ctrls + lora_ctrls + refiner_model, refiner_switch, sampler_name, scheduler_name, vae_name, seed_random, + image_seed, generate_button, load_parameter_button] + freeu_ctrls + lora_ctrls if not args_manager.args.disable_preset_selection: def preset_selection_change(preset, is_generating): @@ -634,7 +637,7 @@ with shared.gradio_root: ctrls += [outpaint_selections, inpaint_input_image, inpaint_additional_prompt, inpaint_mask_image] ctrls += [disable_preview, disable_intermediate_results, disable_seed_increment] ctrls += [adm_scaler_positive, adm_scaler_negative, adm_scaler_end, adaptive_cfg] - ctrls += [sampler_name, scheduler_name] + ctrls += [sampler_name, scheduler_name, vae_name] ctrls += [overwrite_step, overwrite_switch, overwrite_width, overwrite_height, 
overwrite_vary_strength] ctrls += [overwrite_upscale_strength, mixing_image_prompt_and_vary_upscale, mixing_image_prompt_and_inpaint] ctrls += [debugging_cn_preprocessor, skipping_cn_preprocessor, canny_low_threshold, canny_high_threshold]